{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from transformers import BertTokenizer, BertForSequenceClassification, AdamW, BertConfig\n",
    "\n",
    "# 自定义数据集类\n",
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, data_path, tokenizer):\n",
    "        self.sentences, self.labels = self.load_data(data_path)\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def load_data(self, data_path):\n",
    "        sentences = []\n",
    "        labels = []\n",
    "        with open(data_path, 'r', encoding='utf-8') as file:\n",
    "            for line in file:\n",
    "                sentence, label = line.strip().split('&')\n",
    "                sentences.append(sentence)\n",
    "                labels.append(int(label))\n",
    "        return sentences, labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.sentences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sentence = self.sentences[idx]\n",
    "        label = self.labels[idx]\n",
    "\n",
    "        # 使用BERT的tokenizer对句子进行编码\n",
    "        encoding = self.tokenizer.encode_plus(\n",
    "            sentence,\n",
    "            add_special_tokens=True,\n",
    "            padding='max_length',\n",
    "            truncation=True,\n",
    "            max_length=512,\n",
    "            return_tensors='pt'\n",
    "        )\n",
    "\n",
    "        input_ids = encoding['input_ids'].squeeze()\n",
    "        attention_mask = encoding['attention_mask'].squeeze()\n",
    "\n",
    "        return {\n",
    "            'input_ids': input_ids,\n",
    "            'attention_mask': attention_mask,\n",
    "            'label': torch.tensor(label)\n",
    "        }\n",
    "\n",
    "\n",
    "# 设置训练参数\n",
    "data_path = './data.txt'\n",
    "model_name = 'hfl/chinese-roberta-wwm-ext'\n",
    "batch_size = 32\n",
    "label_cnt = 11\n",
    "learning_rate = 1e-4\n",
    "num_epochs = 30\n",
    "\n",
    "# 加载预训练的BERT模型和tokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
    "config = BertConfig.from_pretrained(model_name)\n",
    "config.num_hidden_layers = 2\n",
    "config.num_labels = label_cnt   # 将 num_labels 添加到 config 中\n",
    "model = BertForSequenceClassification.from_pretrained(model_name, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\12236\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# 加载数据集\n",
    "dataset = CustomDataset(data_path, tokenizer)\n",
    "\n",
    "# 划分训练集和验证集\n",
    "train_size = int(0.8 * len(dataset))\n",
    "valid_size = len(dataset) - train_size\n",
    "train_dataset, valid_dataset = torch.utils.data.random_split(\n",
    "    dataset, [train_size, valid_size])\n",
    "\n",
    "# 创建数据加载器\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# 定义优化器和损失函数\n",
    "optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
    "# scheduler = StepLR(optimizer, step_size=30, gamma=0.1)\n",
    "loss_fn = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "if torch.cuda.device_count() > 1:\n",
    "    model = torch.DataParallel(model, device_ids=[0, 1, 2])\n",
    "\n",
    "model.to(device)\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30 - Loss: 2.4697 - Valid Loss: 2.4001 - Accuracy: 0.1515 - Valid Accuracy: 0.1176\n",
      "Epoch 2/30 - Loss: 2.3332 - Valid Loss: 2.4109 - Accuracy: 0.1818 - Valid Accuracy: 0.1176\n",
      "Epoch 3/30 - Loss: 2.3231 - Valid Loss: 2.4203 - Accuracy: 0.2273 - Valid Accuracy: 0.1176\n",
      "Epoch 4/30 - Loss: 2.3509 - Valid Loss: 2.4029 - Accuracy: 0.2273 - Valid Accuracy: 0.1176\n",
      "Epoch 5/30 - Loss: 2.3683 - Valid Loss: 2.3788 - Accuracy: 0.2273 - Valid Accuracy: 0.1176\n",
      "Epoch 6/30 - Loss: 2.2438 - Valid Loss: 2.3475 - Accuracy: 0.2273 - Valid Accuracy: 0.1176\n",
      "Epoch 7/30 - Loss: 2.2199 - Valid Loss: 2.3103 - Accuracy: 0.2424 - Valid Accuracy: 0.1765\n",
      "Epoch 8/30 - Loss: 2.0366 - Valid Loss: 2.2571 - Accuracy: 0.3788 - Valid Accuracy: 0.2941\n",
      "Epoch 9/30 - Loss: 1.6798 - Valid Loss: 2.2387 - Accuracy: 0.4091 - Valid Accuracy: 0.2941\n",
      "Epoch 10/30 - Loss: 1.8996 - Valid Loss: 2.1302 - Accuracy: 0.4697 - Valid Accuracy: 0.2941\n",
      "Epoch 11/30 - Loss: 1.7134 - Valid Loss: 1.9543 - Accuracy: 0.5455 - Valid Accuracy: 0.3529\n",
      "Epoch 12/30 - Loss: 1.3971 - Valid Loss: 1.9282 - Accuracy: 0.5909 - Valid Accuracy: 0.3529\n",
      "Epoch 13/30 - Loss: 1.2754 - Valid Loss: 1.7405 - Accuracy: 0.5758 - Valid Accuracy: 0.4706\n",
      "Epoch 14/30 - Loss: 0.9974 - Valid Loss: 1.6860 - Accuracy: 0.7273 - Valid Accuracy: 0.3529\n",
      "Epoch 15/30 - Loss: 0.7665 - Valid Loss: 1.7870 - Accuracy: 0.7576 - Valid Accuracy: 0.2941\n",
      "Epoch 16/30 - Loss: 0.9351 - Valid Loss: 1.8142 - Accuracy: 0.7576 - Valid Accuracy: 0.3529\n",
      "Epoch 17/30 - Loss: 0.7146 - Valid Loss: 1.8350 - Accuracy: 0.7273 - Valid Accuracy: 0.4118\n",
      "Epoch 18/30 - Loss: 0.5400 - Valid Loss: 1.6096 - Accuracy: 0.8182 - Valid Accuracy: 0.4118\n",
      "Epoch 19/30 - Loss: 0.5634 - Valid Loss: 1.5151 - Accuracy: 0.8788 - Valid Accuracy: 0.4706\n",
      "Epoch 20/30 - Loss: 0.4714 - Valid Loss: 1.5077 - Accuracy: 0.8333 - Valid Accuracy: 0.4706\n",
      "Epoch 21/30 - Loss: 0.4256 - Valid Loss: 1.5432 - Accuracy: 0.8333 - Valid Accuracy: 0.5294\n",
      "Epoch 22/30 - Loss: 0.4206 - Valid Loss: 1.5217 - Accuracy: 0.8788 - Valid Accuracy: 0.5294\n",
      "Epoch 23/30 - Loss: 0.3455 - Valid Loss: 1.6768 - Accuracy: 0.8636 - Valid Accuracy: 0.4706\n",
      "Epoch 24/30 - Loss: 0.2809 - Valid Loss: 1.6367 - Accuracy: 0.8939 - Valid Accuracy: 0.5294\n",
      "Epoch 25/30 - Loss: 0.2649 - Valid Loss: 1.5520 - Accuracy: 0.9394 - Valid Accuracy: 0.5294\n",
      "Epoch 26/30 - Loss: 0.4262 - Valid Loss: 1.3868 - Accuracy: 0.9242 - Valid Accuracy: 0.5882\n",
      "Epoch 27/30 - Loss: 0.3861 - Valid Loss: 1.4102 - Accuracy: 0.9242 - Valid Accuracy: 0.5294\n",
      "Epoch 28/30 - Loss: 0.2784 - Valid Loss: 1.6377 - Accuracy: 0.9091 - Valid Accuracy: 0.5882\n",
      "Epoch 29/30 - Loss: 0.3482 - Valid Loss: 1.9709 - Accuracy: 0.9394 - Valid Accuracy: 0.4706\n",
      "Epoch 30/30 - Loss: 0.3298 - Valid Loss: 2.0915 - Accuracy: 0.9848 - Valid Accuracy: 0.3529\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    correct_predictions = 0\n",
    "    total_predictions = 0\n",
    "\n",
    "    for batch in train_loader:\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['label'].to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        _, predicted_labels = torch.max(outputs.logits, dim=1)\n",
    "        correct_predictions += (predicted_labels == labels).sum().item()\n",
    "        total_predictions += labels.size(0)\n",
    "    \n",
    "    # 在验证集上评估模型\n",
    "    model.eval()\n",
    "    valid_loss = 0\n",
    "    correct_valid_predictions = 0\n",
    "    total_valid_predictions = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in valid_loader:\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['label'].to(device)\n",
    "\n",
    "            outputs = model(input_ids, attention_mask=attention_mask)\n",
    "            logits = outputs.logits\n",
    "            _, predicted_labels = torch.max(logits, dim=1)\n",
    "\n",
    "            valid_loss += loss_fn(logits, labels).item()\n",
    "            correct_valid_predictions += (predicted_labels == labels).sum().item()\n",
    "            total_valid_predictions += labels.size(0)\n",
    "\n",
    "    epoch_loss = total_loss / len(train_loader)\n",
    "    epoch_valid_loss = valid_loss / len(valid_loader)\n",
    "    accuracy = correct_predictions / total_predictions\n",
    "    valid_accuracy = correct_valid_predictions / total_valid_predictions\n",
    "    # scheduler.step()\n",
    "\n",
    "    print(f'Epoch {epoch + 1}/{num_epochs} - Loss: {epoch_loss:.4f} - Valid Loss: {epoch_valid_loss:.4f} - Accuracy: {accuracy:.4f} - Valid Accuracy: {valid_accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8433734939759037\n"
     ]
    }
   ],
   "source": [
    "# 计算整体准确率\n",
    "v_loader = DataLoader(dataset)\n",
    "model.eval()\n",
    "correct_predictions = 0\n",
    "total_predictions = 0\n",
    "with torch.no_grad():\n",
    "        for batch in v_loader:\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['label'].to(device)\n",
    "\n",
    "            outputs = model(input_ids, attention_mask=attention_mask)\n",
    "            logits = outputs.logits\n",
    "            _, predicted_labels = torch.max(logits, dim=1)\n",
    "\n",
    "            correct_predictions += (predicted_labels == labels).sum().item()\n",
    "            total_predictions += labels.size(0)\n",
    "\n",
    "valid_accuracy = correct_predictions / total_predictions\n",
    "\n",
    "print(valid_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保存模型\n",
    "torch.save(model, 'job.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[7], line 50\u001b[0m\n\u001b[0;32m     46\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mdict\u001b[39m(\u001b[39msorted\u001b[39m(predicted_labels\u001b[39m.\u001b[39mitems(), key\u001b[39m=\u001b[39m\u001b[39mlambda\u001b[39;00m x: x[\u001b[39m1\u001b[39m], reverse\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m))\n\u001b[0;32m     49\u001b[0m \u001b[39m# 加载训练好的模型参数\u001b[39;00m\n\u001b[1;32m---> 50\u001b[0m model\u001b[39m.\u001b[39;49mload_state_dict(torch\u001b[39m.\u001b[39;49mload(\u001b[39m'\u001b[39;49m\u001b[39mjob.pt\u001b[39;49m\u001b[39m'\u001b[39;49m))\n\u001b[0;32m     52\u001b[0m \u001b[39m# 设置模型为评估模式\u001b[39;00m\n\u001b[0;32m     53\u001b[0m model\u001b[39m.\u001b[39meval()\n",
      "File \u001b[1;32mc:\\Users\\12236\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:2009\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m   1986\u001b[0m \u001b[39m\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\"\"Copies parameters and buffers from :attr:`state_dict` into\u001b[39;00m\n\u001b[0;32m   1987\u001b[0m \u001b[39mthis module and its descendants. If :attr:`strict` is ``True``, then\u001b[39;00m\n\u001b[0;32m   1988\u001b[0m \u001b[39mthe keys of :attr:`state_dict` must exactly match the keys returned\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   2006\u001b[0m \u001b[39m    ``RuntimeError``.\u001b[39;00m\n\u001b[0;32m   2007\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m   2008\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(state_dict, Mapping):\n\u001b[1;32m-> 2009\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mExpected state_dict to be dict-like, got \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mtype\u001b[39m(state_dict)))\n\u001b[0;32m   2011\u001b[0m missing_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n\u001b[0;32m   2012\u001b[0m unexpected_keys: List[\u001b[39mstr\u001b[39m] \u001b[39m=\u001b[39m []\n",
      "\u001b[1;31mTypeError\u001b[0m: Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>."
     ]
    }
   ],
   "source": [
    "label = {\n",
    "    0: '其他',\n",
    "    1: '产品运营',\n",
    "    2: '平面设计师',\n",
    "    3: '财务',\n",
    "    4: '市场营销',\n",
    "    5: '项目主管',\n",
    "    6: '开发工程师',\n",
    "    7: '文员',\n",
    "    8: '电商运营',\n",
    "    9: '人力资源管理',\n",
    "    10: '风控专员'\n",
    "}\n",
    "\n",
    "\n",
    "def job_predict(sentence, model, tokenizer):\n",
    "    # 对输入句子进行编码\n",
    "    encoding = tokenizer.encode_plus(\n",
    "        sentence,\n",
    "        add_special_tokens=True,\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "        return_tensors='pt'\n",
    "    )\n",
    "\n",
    "    input_ids = encoding['input_ids'].squeeze()\n",
    "    attention_mask = encoding['attention_mask'].squeeze()\n",
    "\n",
    "    # 在模型中进行前向传播\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids.unsqueeze(\n",
    "            0), attention_mask=attention_mask.unsqueeze(0))\n",
    "        logits = outputs.logits\n",
    "\n",
    "    # 获取预测的标签和对应的预测概率值\n",
    "    predicted_probs = torch.softmax(logits, dim=1)\n",
    "    # predicted_label = torch.argmax(predicted_probs, dim=1)\n",
    "\n",
    "    predicted_labels = {}\n",
    "    for index, value in enumerate(predicted_probs.squeeze().tolist()):\n",
    "        # 设置概率阈值 超过该阈值的可以作为候选项 此处 0.1 较合理\n",
    "        if value >= 0.1:\n",
    "            predicted_labels[index] = value\n",
    "    \n",
    "    return dict(sorted(predicted_labels.items(), key=lambda x: x[1], reverse=True))\n",
    "\n",
    "\n",
    "# 加载训练好的模型参数\n",
    "model.load_state_dict(torch.load('job.pt'))\n",
    "\n",
    "# 设置模型为评估模式\n",
    "model.eval()\n",
    "\n",
    "# 测试例子\n",
    "input_sentence = '质管部部长助理.负责票证车间生产质量检查'\n",
    "predicted_labels = job_predict(\n",
    "    input_sentence, model, tokenizer)\n",
    "\n",
    "print(f'输入句子: {input_sentence}')\n",
    "\n",
    "for key,value in predicted_labels.items():\n",
    "    print(f'{label[key]}  --  {value}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
